import copy

# State space: every combination of options that can be available at the
# same time, written as a string of option letters ('' = nothing available).
S = ['ABC', 'AB', 'BC', 'AC', 'A', 'B', 'C', '']
# The individual options; a state "contains" an option when it is available.
options = ('A', 'B', 'C')


class MDP(object):
    """
    A Markov Decision Process over pairs (s, w).

    s is one of the availability states in S and w is a countdown: w == 0
    is a decision epoch, w > 0 means a previously chosen action is still
    being carried out and only the idle action '' is possible.

    Parameters:
        R      -- dict mapping each action (including '') to its reward.
        mu     -- dict mapping each option to the probability that it stops
                  being available in the next timestep when it is available.
        lmb    -- dict mapping each option to the probability that it
                  becomes available in the next timestep when it is not.
        delay  -- dict mapping each action (including '') to the number of
                  timesteps it occupies.
        gamma  -- discount factor applied to next-step values.
    """
    def __init__(self, R, mu, lmb, delay, gamma):
        super(MDP, self).__init__()
        self.R = R
        self.mu = mu
        self.lmb = lmb
        self.delay = delay
        # Number of countdown values tracked per state. At least 1 so that
        # the decision epoch w = 0 exists even when no action has a delay.
        self.maxh = max(1, max(delay.values()))
        self.gamma = gamma

    def actions(self, s, w):
        """
        List of actions that can be performed in state s at time w.

        The idle action '' is always allowed; an option can only be chosen
        at a decision epoch (w == 0) and only if it is available in s.
        """
        lst = ['']

        if w == 0:
            lst.extend(o for o in options if o in s)

        return lst

    def singlePr(self, s, opt):
        """
        Return the probability that an option will be available in the next
        timestep, given the current state s.
        """
        if opt in s:
            return 1 - self.mu[opt]
        return self.lmb[opt]

    def transPr(self, s0, s1):
        """
        Return the probability to transition to state s' (s1) from state s
        (s0), as the product over all options of the probability that the
        option is (or is not) available in the next timestep.
        """
        p = 1
        for o in options:
            avail = self.singlePr(s0, o)
            p *= avail if o in s1 else 1 - avail
        return p

    def _nextW(self, w0, a):
        """Countdown value reached one timestep after taking a at time w0."""
        if w0 == 0:
            # A zero-delay action (e.g. the idle action '') leads straight to
            # the next decision epoch rather than to an impossible w' = -1.
            return max(self.delay[a] - 1, 0)
        # An action is still in progress: just count down.
        return w0 - 1

    def delayPr(self, s0, w0, a, s1, w1):
        """
        Return probability to transition from state (s, w) to state (s', w')
        in the next timestep, after taking action a.

        Only one w' is reachable from a given (w, a); for any other w' the
        result is 0. For the reachable w' it is transPr(s, s').
        """
        if w1 != self._nextW(w0, a):
            return 0
        return self.transPr(s0, s1)

    def _qValue(self, s, w, a, V):
        """Reward of a plus the discounted expected value of the successor."""
        # Only one countdown value is reachable, so summing over s' suffices.
        w1 = self._nextW(w, a)
        return self.R[a] + sum(self.transPr(s, s1) * self.gamma * V[s1][w1]
                               for s1 in S)

    def valIter(self, maxIter = 500, epsilon = 1e-8):
        """
        Solve MDP by value iteration with a maximum number of iterations and
        a given tolerance epsilon.

        Iteration stops when the span (max minus min) of the change between
        two successive value tables drops to the threshold, or after maxIter
        sweeps. Return the latest value table, a dict mapping each state to
        a list indexed by w, and the number of sweeps performed.
        """
        V1 = {s: [0.] * self.maxh for s in S}

        # The scaled threshold is only defined for 0 < gamma < 1; otherwise
        # (e.g. gamma == 1, or gamma == 0 which would divide by zero) use
        # epsilon directly.
        if 0 < self.gamma < 1:
            threshold = epsilon * (1 - self.gamma) / self.gamma
        else:
            threshold = epsilon

        i = 0
        delta = float('inf')

        while i < maxIter and delta > threshold:
            i += 1
            # Start from +/- infinity so that the bounds are correct for any
            # reward scale and sign.
            mn = float('inf')
            Mn = float('-inf')

            V = copy.deepcopy(V1)

            # Bellman equation
            for s in S:
                for w in range(self.maxh):
                    V1[s][w] = max(self._qValue(s, w, a, V)
                                   for a in self.actions(s, w))
                    # calculate span semi-norm
                    diff = V1[s][w] - V[s][w]
                    mn = min(mn, diff)
                    Mn = max(Mn, diff)

            delta = Mn - mn

        return V1, i

    def expectancy(self, s, a, V):
        """
        The expected utility of doing a in state s at a decision epoch
        (w = 0), according to V: the reward of a plus the discounted
        expected value of the next state.
        """
        return self._qValue(s, 0, a, V)

    def optiPi(self, V):
        """
        Determine the optimal policy from the optimal values V. Actions taken
        at decision epochs, i.e. w = 0.

        Return a dict mapping each state to its best action; ties are
        resolved in favour of the action listed first by actions().
        """
        pi = {}
        for s in S:
            pi[s] = max(self.actions(s, 0),
                        key = lambda a: self.expectancy(s, a, V))
        return pi


def printPi(pi, states=None):
    """
    Display policy in pretty format, one "state -> action" line per state.
    States without an entry in pi are shown with a '?' instead of an action.

    pi      -- dict mapping states to actions.
    states  -- optional iterable of states to list, in display order;
               defaults to the module-level state list S.
    """
    if states is None:
        states = S
    for s in states:
        # A single pre-formatted string prints identically under the Python 2
        # print statement and the Python 3 print function.
        if s in pi:
            print('%s -> %s' % (s, pi[s]))
        else:
            print('%s -> ?' % (s,))


# Usage example: build a model with three options, solve it by value
# iteration (at most 500 sweeps) and print the resulting policy.
case = MDP(
    R={'A': 5, 'B': 2, 'C': 1, '': 0},
    mu={'A': 0.5, 'B': 0.5, 'C': 0.5},
    lmb={'A': 0.01, 'B': 0.5, 'C': 0.01},
    delay={'A': 6, 'B': 2, 'C': 1, '': 0},
    gamma=1,
)

V, i = case.valIter(maxIter=500)
pi = case.optiPi(V)
printPi(pi)
